{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference Example with Medusa\n",
    "\n",
    "In this Jupyter notebook, we're going to demonstrate how to perform inference using the Medusa model on an interesting story prompt. Let's get the ball rolling!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tianle/anaconda3/envs/medusa_v1.0/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\" # define GPU id, remove if you want to use all GPUs available\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "from contextlib import contextmanager\n",
    "import numpy as np\n",
    "from medusa.model.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM\n",
    "from medusa.model.medusa_model import MedusaModel\n",
    "from medusa.model.kv_cache import *\n",
    "from medusa.model.utils import *\n",
    "from medusa.model.medusa_choices import *\n",
    "import transformers\n",
    "from huggingface_hub import hf_hub_download"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Medusa Forward Function\n",
    "\n",
    "We define `medusa_forward`, the function that generates a story from a prompt, together with `timed`, a small context manager that records the wall time of each decoding stage.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@contextmanager\n",
    "def timed(wall_times, key):\n",
    "    start = time.time()\n",
    "    torch.cuda.synchronize()\n",
    "    yield\n",
    "    torch.cuda.synchronize()\n",
    "    end = time.time()\n",
    "    elapsed_time = end - start\n",
    "    wall_times[key].append(elapsed_time)\n",
    "\n",
    "def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):\n",
    "    \"\"\"Run Medusa decoding on ``input_ids`` and time every stage.\n",
    "\n",
    "    Args:\n",
    "        input_ids: Prompt token ids, indexed as ``[batch, position]`` below;\n",
    "            only row 0 is checked for EOS.\n",
    "        model: Model to decode with. The Medusa buffers and the key/value\n",
    "            cache are stored on it as attributes so later calls can reuse them.\n",
    "        tokenizer: Only ``eos_token_id`` is read, to detect the end of generation.\n",
    "        medusa_choices: Tree specification handed to ``generate_medusa_buffers``.\n",
    "        temperature, posterior_threshold, posterior_alpha: Forwarded unchanged\n",
    "            to ``evaluate_posterior``.\n",
    "        max_steps: Upper bound on the number of decoding iterations.\n",
    "\n",
    "    Returns:\n",
    "        ``(input_ids, new_token, idx, wall_times)``: the ids and the token\n",
    "        counter as last returned by ``update_inference_inputs``, the zero-based\n",
    "        index of the last iteration executed (so ``idx + 1`` iterations ran),\n",
    "        and a dict mapping each stage name to its list of wall times.\n",
    "\n",
    "    NOTE(review): with ``max_steps <= 0`` the loop never runs and ``idx`` is\n",
    "    unbound at the return statement.\n",
    "    \"\"\"\n",
    "    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}\n",
    "    \n",
    "    with timed(wall_times, 'init'):\n",
    "        if hasattr(model, \"medusa_choices\") and model.medusa_choices == medusa_choices:\n",
    "            # Load the cached medusa buffer\n",
    "            medusa_buffers = model.medusa_buffers\n",
    "        else:\n",
    "            # Initialize the medusa buffer\n",
    "            medusa_buffers = generate_medusa_buffers(\n",
    "                medusa_choices, device=model.base_model.device\n",
    "            )\n",
    "        # Cache on the model so a later call with the same choices skips the rebuild.\n",
    "        model.medusa_buffers = medusa_buffers\n",
    "        model.medusa_choices = medusa_choices\n",
    "\n",
    "        # Initialize the past key and value states\n",
    "        if hasattr(model, \"past_key_values\"):\n",
    "            past_key_values = model.past_key_values\n",
    "            past_key_values_data = model.past_key_values_data\n",
    "            current_length_data = model.current_length_data\n",
    "            # Reset the past key and value states\n",
    "            current_length_data.zero_()\n",
    "        else:\n",
    "            (\n",
    "                past_key_values,\n",
    "                past_key_values_data,\n",
    "                current_length_data,\n",
    "            ) = initialize_past_key_values(model.base_model)\n",
    "            model.past_key_values = past_key_values\n",
    "            model.past_key_values_data = past_key_values_data\n",
    "            model.current_length_data = current_length_data\n",
    "\n",
    "        # Prompt length; everything past this offset counts as generated text.\n",
    "        input_len = input_ids.shape[1]\n",
    "        reset_medusa_mode(model)\n",
    "        medusa_logits, logits = initialize_medusa(\n",
    "                input_ids, model, medusa_buffers[\"medusa_attn_mask\"], past_key_values\n",
    "        )\n",
    "    # Threaded through update_inference_inputs, which returns the updated value;\n",
    "    # presumably the running count of generated tokens - TODO confirm.\n",
    "    new_token = 0\n",
    "\n",
    "    # Each iteration runs the four timed stages in order: candidate generation,\n",
    "    # tree decoding, posterior evaluation and input update.\n",
    "    for idx in range(max_steps): \n",
    "        with timed(wall_times, 'medusa'):\n",
    "            candidates, tree_candidates = generate_candidates(\n",
    "                    medusa_logits,\n",
    "                    logits,\n",
    "                    medusa_buffers[\"tree_indices\"],\n",
    "                    medusa_buffers[\"retrieve_indices\"],\n",
    "                )\n",
    "\n",
    "        with timed(wall_times, 'tree'):\n",
    "            medusa_logits, logits, outputs = tree_decoding(\n",
    "                    model,\n",
    "                    tree_candidates,\n",
    "                    past_key_values,\n",
    "                    medusa_buffers[\"medusa_position_ids\"],\n",
    "                    input_ids,\n",
    "                    medusa_buffers[\"retrieve_indices\"],\n",
    "                )\n",
    "\n",
    "        with timed(wall_times, 'posterior'):\n",
    "            best_candidate, accept_length = evaluate_posterior(\n",
    "                    logits, candidates, temperature, posterior_threshold, posterior_alpha\n",
    "                )\n",
    "        \n",
    "        with timed(wall_times, 'update'):\n",
    "            input_ids, logits, medusa_logits, new_token = update_inference_inputs(\n",
    "                    input_ids,\n",
    "                    candidates,\n",
    "                    best_candidate,\n",
    "                    accept_length,\n",
    "                    medusa_buffers[\"retrieve_indices\"],\n",
    "                    outputs,\n",
    "                    logits,\n",
    "                    medusa_logits,\n",
    "                    new_token,\n",
    "                    past_key_values_data,\n",
    "                    current_length_data,\n",
    "                )\n",
    "\n",
    "        # Stop once an EOS token appears anywhere in the generated part of row 0.\n",
    "        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():\n",
    "            break\n",
    "\n",
    "    return input_ids, new_token, idx, wall_times\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Loading\n",
    "\n",
    "We load the model and tokenizer using the specified paths and configurations.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifier handed to MedusaModel.from_pretrained (presumably a Hugging Face Hub repo id - confirm).\n",
    "model_name = 'FasterDecoding/medusa-vicuna-7b-v1.3'\n",
    "# Load the weights in float16 and request automatic device placement.\n",
    "model = MedusaModel.from_pretrained(\n",
    "    model_name,\n",
    "    medusa_num_heads = 4,\n",
    "    torch_dtype=torch.float16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=\"auto\"\n",
    ")\n",
    "tokenizer = model.get_tokenizer()\n",
    "\n",
    "# mc_sim_7b_63 is not defined in this notebook; it presumably comes from the\n",
    "# wildcard import of medusa.model.medusa_choices - TODO confirm.\n",
    "medusa_choices = mc_sim_7b_63\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting Inference Parameters\n",
    "\n",
    "Next, we set some parameters that will be used during the inference process.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoding settings; medusa_forward passes all three unchanged to evaluate_posterior.\n",
    "temperature = 0.\n",
    "posterior_threshold = 0.09\n",
    "posterior_alpha = 0.3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting The Prompt\n",
    "\n",
    "The following is the story prompt we will use for generating our story in the demo.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-turn chat prompt: a system preamble, one USER request, and an open ASSISTANT: turn at the end.\n",
    "prompt = \"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hi, could you share a tale about a charming llama that grows Medusa-like hair and starts its own coffee shop? ASSISTANT:\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performing Inference\n",
    "\n",
    "Using the set parameters and the defined function, let's generate our story!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output length: 403\n",
      "Compression ratio: tensor(2.4724, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "with torch.inference_mode():\n",
    "    # Tokenize the prompt as a batch of one.\n",
    "    input_ids = tokenizer([prompt]).input_ids\n",
    "    output_ids, new_token, idx, wall_time = medusa_forward(\n",
    "        torch.as_tensor(input_ids).cuda(),\n",
    "        model,\n",
    "        tokenizer,\n",
    "        medusa_choices,\n",
    "        temperature,\n",
    "        posterior_threshold,\n",
    "        posterior_alpha,\n",
    "    )\n",
    "    # Drop the prompt tokens so only the newly generated ids remain.\n",
    "    output_ids = output_ids[0][len(input_ids[0]) :]\n",
    "    print(\"Output length:\", output_ids.size(-1))\n",
    "    # medusa_forward returns the zero-based index of its last decoding step, so\n",
    "    # idx + 1 steps were run; dividing by idx overstates the ratio and divides\n",
    "    # by zero when generation stops in the very first step.\n",
    "    print(\"Compression ratio:\", new_token / (idx + 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoding The Output\n",
    "\n",
    "Let's decode the generated output to obtain our story.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Once upon a time, in a small village nestled in the Andes mountains, there lived a charming llama named Luna. Luna was known for her kind heart and her love of coffee. She would often spend her afternoons sipping on a steaming cup of joe at the local café, chatting with the villagers and enjoying the warmth of the sun on her back.\n",
      "\n",
      "One day, as Luna was grazing on some fresh grass, she noticed that her hair was starting to grow longer and thicker. At first, she didn't think much of it, but as the days went on, her hair continued to grow and change. It became thick and wiry, with sharp spikes protruding from it.\n",
      "\n",
      "Luna was confused and a little scared by her new appearance. She had always been a gentle creature, and now she looked like a monster. She knew that she couldn't stay in the village anymore, so she set off on a journey to find a new home.\n",
      "\n",
      "As she wandered through the mountains, Luna stumbled upon a beautiful clearing. In the center of the clearing stood a small cottage, with a sign hanging outside that read \"Café Llama.\" Luna knew that this was where she belonged.\n",
      "\n",
      "She transformed the cottage into a cozy coffee shop, serving the best coffee in the mountains. The villagers were amazed by Luna's transformation, and they flocked to her café to taste her delicious brews.\n",
      "\n",
      "Luna's Medusa-like hair became her signature style, and she quickly became known as the most charming llama in the land. She spent her days sipping coffee, chatting with customers, and enjoying the warmth of the sun on her back. And she knew that she had finally found her true home.</s>\n"
     ]
    }
   ],
   "source": [
    "# Turn the generated ids back into text and show the story.\n",
    "output = tokenizer.decode(output_ids, spaces_between_special_tokens=False)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyzing Wall Times\n",
    "\n",
    "We will now break down and analyze the wall times during the inference process.\n",
    "\n",
    "You might notice that the initialization phase takes a significant amount of time. On the first call, `medusa_forward` has to build the Medusa buffers and allocate the key/value cache on the GPU; later calls reuse both.\n",
    "\n",
    "For a clearer picture, rerun the inference cell and compare the timings of the second run with the first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "Wall time init:                              0.026\n",
      "Wall time medusa:                            0.031\n",
      "Wall time Tree:                              3.732\n",
      "Wall time Posterior:                         0.025\n",
      "Wall time Update:                            0.051\n",
      "--------------------------------------------------\n",
      "Wall time portion medusa:                    0.008\n",
      "Wall time portion Tree:                      0.965\n",
      "Wall time portion Posterior:                 0.007\n",
      "Wall time portion Update:                    0.013\n",
      "--------------------------------------------------\n",
      "Tokens/second:                             104.247\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "# Character width of the report: separator rules use it directly and each row is padded to it.\n",
    "max_length = 50\n",
    "\n",
    "def format_string(text, value, max_length):\n",
    "    value_str = \"{:.3f}\".format(value)\n",
    "    return f\"{text:<{max_length - len(value_str)}}{value_str}\"\n",
    "\n",
    "# Collapse the recorded timings of every stage into one total per stage.\n",
    "time_init, time_medusa, time_tree, time_posterior, time_update = (\n",
    "    np.sum(wall_time[stage])\n",
    "    for stage in ('init', 'medusa', 'tree', 'posterior', 'update')\n",
    ")\n",
    "time_total = time_init + time_medusa + time_tree + time_posterior + time_update\n",
    "\n",
    "absolute_rows = [\n",
    "    (\"Wall time init: \", time_init),\n",
    "    (\"Wall time medusa: \", time_medusa),\n",
    "    (\"Wall time Tree: \", time_tree),\n",
    "    (\"Wall time Posterior: \", time_posterior),\n",
    "    (\"Wall time Update: \", time_update),\n",
    "]\n",
    "# Shares of the total; init is part of the total but has no row of its own.\n",
    "portion_rows = [\n",
    "    (\"Wall time portion medusa: \", time_medusa / time_total),\n",
    "    (\"Wall time portion Tree: \", time_tree / time_total),\n",
    "    (\"Wall time portion Posterior: \", time_posterior / time_total),\n",
    "    (\"Wall time portion Update: \", time_update / time_total),\n",
    "]\n",
    "throughput_rows = [(\"Tokens/second: \", new_token / time_total)]\n",
    "\n",
    "print('=' * max_length)\n",
    "for section_index, rows in enumerate((absolute_rows, portion_rows, throughput_rows)):\n",
    "    # A dashed rule separates consecutive sections.\n",
    "    if section_index:\n",
    "        print('-' * max_length)\n",
    "    for label, value in rows:\n",
    "        print(format_string(label, value, max_length))\n",
    "print('=' * max_length)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "medusa",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
